import numpy as np
import pandas as pd
import torch
import calibration as cal
import matplotlib.pyplot as plt
from dataclasses import dataclass

from src.data import *
from src.utils import *
from src.models import *

from matplotlib import style
# Select the 'seaborn-v0_8' matplotlib style sheet as a module-import side
# effect, so figures drawn by the functions below use it.
plt.style.use('seaborn-v0_8')


def _draw_ece_figure(curve_groups, test_cal_scores, args):
    """Draw one two-panel figure (ECE-1 left, ECE-2 right) of error curves.

    Args:
        curve_groups: Sequence of ``(results, fmt)`` pairs. ``results`` maps a
            curve name to ``(X, y)``, where ``y["ece_1"]`` / ``y["ece_2"]`` are
            plotted against ``X``. ``fmt`` is a tuple of extra positional
            arguments forwarded to ``Axes.plot`` (``()`` for the default line
            style, ``("--",)`` for dashed). It is iterated once per panel, so
            it must not be a one-shot iterator.
        test_cal_scores: Mapping with "ECE-1" and "ECE-2" entries; each is
            drawn as a horizontal dashed black line labelled "full".
        args: Namespace providing ``low_quantile``, ``high_quantile`` (read as
            percentages and divided by 100) and ``dataset`` (figure title).

    Returns:
        The created matplotlib figure.
    """
    # Global rcParams is mutated on purpose, as the original code did.
    plt.rcParams["figure.figsize"] = (12, 4)
    fig, ax = plt.subplots(1, 2)
    coverage_span = [args.low_quantile / 100, args.high_quantile / 100]

    # (key inside each curve's y-dict, key inside test_cal_scores / panel title)
    panels = (("ece_1", "ECE-1"), ("ece_2", "ECE-2"))
    for axis, (curve_key, score_key) in zip(ax, panels):
        for results, fmt in curve_groups:
            for name, (X, y) in results.items():
                axis.plot(X, y[curve_key], *fmt, label=name)

        full_score = test_cal_scores[score_key]
        axis.plot(coverage_span, [full_score, full_score], "--", label="full", color="k")
        axis.legend(loc="lower right", ncol=2, fontsize=10)
        axis.set_title(score_key)
        axis.set_xlabel("Coverage")
        axis.set_ylim(0.0, None)

    # Only the left panel carries the shared y-axis label.
    ax[0].set_ylabel("Calibration Error")
    plt.suptitle(args.dataset)
    return fig


def _save_and_show(args, suffix, description):
    """Optionally save the current figure as a PNG, then display it.

    Args:
        args: Namespace providing ``save_results``, ``val_split``,
            ``test_split`` and ``n_val_samples``.
        suffix: Text inserted before ".png" in the file name.
        description: Text inserted into the "Saving ... plot to" message.
    """
    if args.save_results:
        # NOTE(review): `plt_save_root` is not defined in this file; presumably
        # it is provided by one of the star imports -- confirm.
        plt_save_path = "{}/{}_{}_{}{}.png".format(
            plt_save_root, args.val_split, args.test_split, args.n_val_samples, suffix
        )
        print("Saving {} plot to".format(description), plt_save_path)
        plt.savefig(plt_save_path, bbox_inches="tight")
    plt.show()


def plot_results(trial_results, baseline_results, test_cal_scores, args):
    """Plot ECE-1 / ECE-2 against coverage, without and then with baselines.

    A first figure shows only the curves in ``trial_results``. When
    ``args.baselines`` is not None, a second figure repeats those curves and
    adds the ``baseline_results`` curves as dashed lines. Each figure is saved
    when ``args.save_results`` is truthy and is always shown.

    Args:
        trial_results: Mapping of curve name -> ``(X, y)`` with ``y["ece_1"]``
            and ``y["ece_2"]`` plotted against ``X``.
        baseline_results: Mapping with the same layout as ``trial_results``;
            only read when ``args.baselines`` is not None.
        test_cal_scores: Mapping with "ECE-1" and "ECE-2" reference scores.
        args: Namespace providing ``low_quantile``, ``high_quantile``,
            ``dataset``, ``save_results``, ``baselines``, ``val_split``,
            ``test_split`` and ``n_val_samples``.
    """
    _draw_ece_figure([(trial_results, ())], test_cal_scores, args)
    _save_and_show(args, "", "no baselines")

    if args.baselines is None:
        return

    _draw_ece_figure(
        [(trial_results, ()), (baseline_results, ("--",))],
        test_cal_scores,
        args,
    )
    _save_and_show(args, "_baselines", "baselines")

    
# def get_model_name(args, net_name):

#     model_name = "split_{}_n_{}_net_{}_covwht_{}_seed_{}_scaling_{}_ep_{}".format(
#         args.val_split, args.n_val_samples, net_name, args.coverage_weight, args.seed, args.scaling, args.n_epoch
#     )
#     if args.cov_loss == "quad":
#         model_name += "_beta_{}".format(args.beta)
#     if args.is_wilds and args.noising:
#         model_name += "_noising"
#     return model_name


def train_h(dataset):
    """Fit a top-label Platt calibrator on ``dataset`` and return it.

    The calibrator is sized with the number of rows in ``dataset.probs`` and
    10 bins, then trained on ``dataset.probs`` / ``dataset.labels``.
    """
    calibrator = PlattTopCalibrator(dataset.probs.shape[0], num_bins=10)
    calibrator.train_calibration(dataset.probs, dataset.labels)
    return calibrator


def calc_score(scoring_metric, probs, labels):
    """Compute a calibration / scoring metric for predictions.

    Args:
        scoring_metric: One of "ECE-1", "ECE-2" or "Brier".
        probs: Predicted probabilities, forwarded unchanged to the metric.
        labels: Ground-truth labels, forwarded unchanged to the metric.

    Returns:
        The value returned by the selected metric function.

    Raises:
        ValueError: If ``scoring_metric`` is not a recognised name.
    """
    if scoring_metric == "ECE-1":
        # p=1, no debiasing, 15 bins, top-label mode.
        return cal.lower_bound_scaling_ce(probs, labels, p=1, debias=False, num_bins=15,
                                          mode="top-label")
    if scoring_metric == "ECE-2":
        # p=2 with debiasing enabled, 15 bins, top-label mode.
        return cal.lower_bound_scaling_ce(probs, labels, p=2, debias=True, num_bins=15,
                                          mode="top-label")
    if scoring_metric == "Brier":
        return brier_multi(probs, labels)

    # Previously the exception class was *returned*, so an unknown metric
    # silently produced a non-numeric "score" instead of failing.
    raise ValueError(
        "Unknown scoring_metric {!r}; expected 'ECE-1', 'ECE-2' or 'Brier'".format(
            scoring_metric
        )
    )
    
    
class PlattTopCalibrator:
    """Platt-scaling calibrator for the top-label (max) probability."""

    def __init__(self, num_calibration, num_bins):
        """Store the minimum calibration-set size and the bin count.

        ``num_bins`` is only stored; nothing in this class reads it.
        """
        self._num_calibration = num_calibration
        self._num_bins = num_bins

    def train_calibration(self, probs, labels):
        """Fit the Platt scaler on top-label confidence vs. correctness.

        ``probs`` must contain at least ``num_calibration`` rows. The fitted
        calibration function is kept in ``self._platt`` and the underlying
        classifier in ``self.clf``.
        """
        assert len(probs) >= self._num_calibration
        predicted = np.argmax(probs, axis=-1)
        confidence = np.max(probs, axis=-1)
        # Target of the scaler: was the top-label prediction right?
        is_correct = predicted == labels
        scaler, classifier = get_platt_scaler(confidence, is_correct, get_clf=True)
        self._platt = scaler
        self.clf = classifier

    def calibrate(self, probs):
        """Return calibrated confidences for the top-label probability of each row."""
        confidence = np.max(probs, axis=-1)
        return self._platt(confidence)
    

def get_platt_scaler(model_probs, labels, get_clf=False):
    """Fit a logistic regression on logit(model_probs) and return a calibrator.

    Args:
        model_probs: Array of probabilities; clipped to [1e-12, 1 - 1e-12]
            before the logit transform.
        labels: Targets passed to ``clf.fit``.
        get_clf: When True, also return the fitted classifier.

    Returns:
        ``calibrator`` (a function mapping probabilities to calibrated
        probabilities), or ``(calibrator, clf)`` when ``get_clf`` is True.
    """
    # NOTE(review): `LogisticRegression` is not imported explicitly in this
    # file; presumably it arrives through one of the star imports -- confirm.
    clf = LogisticRegression(C=1e10, solver='lbfgs', fit_intercept=False)
    eps = 1e-12

    def _logit(p):
        # Clip away exact 0/1 so the log-odds stay finite.
        p = np.clip(p, eps, 1 - eps)
        return np.log(p / (1 - p))

    # The classifier expects a 2-D design matrix with a single feature column.
    train_probs = np.expand_dims(model_probs.astype(dtype=np.float64), axis=-1)
    clf.fit(_logit(train_probs), labels)

    def calibrator(probs):
        z = _logit(np.array(probs, dtype=np.float64))
        z = z * clf.coef_[0] + clf.intercept_
        return 1 / (1 + np.exp(-z))

    return (calibrator, clf) if get_clf else calibrator


@dataclass
class Group:
    """Features and labels of one synthetic data group."""

    # Feature matrix; load_dataset builds it with one row per sample and two
    # columns (x_1, x_2).
    X: np.ndarray
    # Labels; load_dataset fills it with one 0/1 value per sample.
    y: np.ndarray
    
    
def y_generator(probs, w=1.0, b=0.0, eps=1e-8, sigmoid=True):
    """Apply an affine map in logit space to ``probs``.

    ``probs`` is clamped to [eps, 1 - eps], converted to log-odds, scaled by
    ``w`` and shifted by ``b``. The result is passed through a sigmoid when
    ``sigmoid`` is True; otherwise the transformed log-odds are returned.
    """
    # NOTE(review): for float32 inputs the default upper bound 1 - 1e-8
    # presumably rounds to 1.0, making the upper clamp a no-op -- confirm.
    clipped = torch.clamp(probs, min=eps, max=(1.0 - eps))
    log_odds = torch.log(clipped / (1 - clipped))
    shifted = log_odds * w + b
    return torch.sigmoid(shifted) if sigmoid else shifted


def _sample_group(n_samples, w, b, mu, sig, tag):
    """Sample one synthetic group and print its calibration diagnostics.

    Args:
        n_samples: Number of samples to draw.
        w, b: Scale and shift handed to ``y_generator`` to obtain the label
            probability from the first feature.
        mu, sig: Mean and standard deviation of the normally distributed
            second feature.
        tag: Prefix used in the printed diagnostics (e.g. "g1").

    Returns:
        ``(group, probs)`` where ``group`` is a ``Group`` holding the
        (n_samples, 2) feature matrix and the 0/1 labels, and ``probs`` is the
        (n_samples, 2) array ``[1 - f_x, f_x]``.
    """
    # First feature: sigmoid of a uniform draw on [-5, 5]. It is also used
    # directly as the class-1 probability f_x below.
    x_1 = torch.sigmoid(torch.Tensor(np.random.uniform(-5, 5, n_samples)))
    y_p = y_generator(x_1, w, b).numpy()
    f_x = x_1.numpy()

    # RNG calls stay in the original order: uniform, normal, binomial.
    x_2 = np.random.normal(mu, sig, n_samples)
    # One Bernoulli draw per sample; cast keeps the labels as float64.
    y = np.random.binomial(1, p=y_p).astype(np.float64)

    group = Group(np.vstack([f_x, x_2]).T, y)

    probs = np.array([1 - f_x, f_x]).T
    y_hat = np.argmax(probs, -1)
    y_correct = np.array((y_hat == y), dtype=int)
    top_prob = np.max(probs, -1)
    acc = np.sum(y_correct) / len(y_correct)

    print("{} ece1".format(tag), calc_score("ECE-1", top_prob, y_correct))
    print("{} ece2".format(tag), calc_score("ECE-2", top_prob, y_correct))
    print("{} acc".format(tag), acc)

    return group, probs


def load_dataset(args, split="val", return_groups=False):
    """Generate a two-group synthetic dataset plus a shuffled DataLoader.

    A fraction ``args.beta`` (truncated to an integer count) of the samples
    goes to group 1 and the remainder to group 2. The total is
    ``args.total_data`` for the "val" split and ``args.n_test_samples`` for
    the "test" split. Per-group diagnostics are printed and a scatter plot of
    both groups is shown.

    Args:
        args: Namespace providing ``mu_1``, ``sig_1``, ``w_1``, ``b_1``,
            ``mu_2``, ``sig_2``, ``w_2``, ``b_2``, ``beta``, ``batch_size``
            and ``total_data`` / ``n_test_samples``.
        split: "val" or "test".
        return_groups: When True, also return the two ``Group`` objects.

    Returns:
        ``(dataset, data_loader)`` or, when ``return_groups`` is True,
        ``(dataset, data_loader, group_1, group_2)``.

    Raises:
        ValueError: If ``split`` is neither "val" nor "test".
    """
    # mu_1, sig_1 = 0.0, 0.25
    mu_1, sig_1 = args.mu_1, args.sig_1
    w_1, b_1 = args.w_1, args.b_1

    # mu_2, sig_2 = 1.0, 0.1
    mu_2, sig_2 = args.mu_2, args.sig_2
    w_2, b_2 = args.w_2, args.b_2

    if split == "val":
        n_total = args.total_data
    elif split == "test":
        n_total = args.n_test_samples
    else:
        raise ValueError(
            "Unknown split {!r}; expected 'val' or 'test'".format(split)
        )
    no_split_1 = int(args.beta * n_total)
    no_split_2 = n_total - no_split_1

    group_1, logits_1 = _sample_group(no_split_1, w_1, b_1, mu_1, sig_1, "g1")
    print()
    group_2, logits_2 = _sample_group(no_split_2, w_2, b_2, mu_2, sig_2, "g2")

    # The stacked [1 - f_x, f_x] arrays are passed under the `logits` keyword
    # together with probs_only=True.
    dataset = Dataset(
        logits=np.vstack([logits_1, logits_2]),
        labels=np.concatenate([group_1.y, group_2.y]),
        features=np.vstack([group_1.X, group_2.X]),
        probs_only=True
    )

    dset = torch.utils.data.TensorDataset(
        torch.Tensor(dataset.features),
        torch.Tensor(dataset.logits),
        torch.Tensor(dataset.labels)
    )
    data_loader = torch.utils.data.DataLoader(
        dset,
        batch_size=args.batch_size,
        shuffle=True,
        pin_memory=True,
        # num_workers=2
    )

    # Scatter of both groups in feature space; the "val" split uses larger,
    # more opaque markers than the "test" split.
    plt.rcParams["figure.figsize"] = (4, 3)
    s_size, opac = (1.5, 0.5) if split == "val" else (1.0, 0.15)
    for group in (group_1, group_2):
        plt.scatter(group.X[:, 0], group.X[:, 1], alpha=opac, s=s_size)
    plt.xlabel(r"$x_1$")
    plt.ylabel(r"$x_2$")
    plt.show()

    if return_groups:
        return dataset, data_loader, group_1, group_2
    return dataset, data_loader

    
    
